{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "!pip install -Uqq fastbook\n",
    "import fastbook\n",
    "fastbook.setup_book()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastbook import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CNN Interpretation with CAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CAM and Hooks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)/'images'\n",
    "def is_cat(x): return x[0].isupper()\n",
    "dls = ImageDataLoaders.from_name_func(\n",
    "    path, get_image_files(path), valid_pct=0.2, seed=21,\n",
    "    label_func=is_cat, item_tfms=Resize(224))\n",
    "learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
    "learn.fine_tune(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = PILImage.create(image_cat())\n",
    "x, = first(dls.test_dl([img]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Hook():\n",
    "    def hook_func(self, m, i, o): self.stored = o.detach().clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook_output = Hook()\n",
    "hook = learn.model[0].register_forward_hook(hook_output.hook_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad(): output = learn.model.eval()(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "act = hook_output.stored[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.softmax(output, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)\n",
    "cam_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_dec = TensorImage(dls.train.decode((x,))[0][0])\n",
    "_,ax = plt.subplots()\n",
    "x_dec.show(ctx=ax)\n",
    "ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
    "              interpolation='bilinear', cmap='magma');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Hook():\n",
    "    def __init__(self, m):\n",
    "        self.hook = m.register_forward_hook(self.hook_func)   \n",
    "    def hook_func(self, m, i, o): self.stored = o.detach().clone()\n",
    "    def __enter__(self, *args): return self\n",
    "    def __exit__(self, *args): self.hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Hook(learn.model[0]) as hook:\n",
    "    with torch.no_grad(): output = learn.model.eval()(x.cuda())\n",
    "    act = hook.stored"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradient CAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HookBwd():\n",
    "    def __init__(self, m):\n",
    "        self.hook = m.register_backward_hook(self.hook_func)   \n",
    "    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()\n",
    "    def __enter__(self, *args): return self\n",
    "    def __exit__(self, *args): self.hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls = 1\n",
    "with HookBwd(learn.model[0]) as hookg:\n",
    "    with Hook(learn.model[0]) as hook:\n",
    "        output = learn.model.eval()(x.cuda())\n",
    "        act = hook.stored\n",
    "    output[0,cls].backward()\n",
    "    grad = hookg.stored"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = grad[0].mean(dim=[1,2], keepdim=True)\n",
    "cam_map = (w * act[0]).sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_,ax = plt.subplots()\n",
    "x_dec.show(ctx=ax)\n",
    "ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
    "              interpolation='bilinear', cmap='magma');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with HookBwd(learn.model[0][-2]) as hookg:\n",
    "    with Hook(learn.model[0][-2]) as hook:\n",
    "        output = learn.model.eval()(x.cuda())\n",
    "        act = hook.stored\n",
    "    output[0,cls].backward()\n",
    "    grad = hookg.stored"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = grad[0].mean(dim=[1,2], keepdim=True)\n",
    "cam_map = (w * act[0]).sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_,ax = plt.subplots()\n",
    "x_dec.show(ctx=ax)\n",
    "ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
    "              interpolation='bilinear', cmap='magma');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. What is a \"hook\" in PyTorch?\n",
    "1. Which layer does CAM use the outputs of?\n",
    "1. Why does CAM require a hook?\n",
    "1. Look at the source code of the `ActivationStats` class and see how it uses hooks.\n",
    "1. Write a hook that stores the activations of a given layer in a model (without peeking, if possible).\n",
    "1. Why do we call `eval` before getting the activations? Why do we use `no_grad`?\n",
    "1. Use `torch.einsum` to compute the \"dog\" or \"cat\" score of each of the locations in the last activation of the body of the model.\n",
    "1. How do you check which order the categories are in (i.e., the correspondence of index->category)?\n",
    "1. Why are we using `decode` when displaying the input image?\n",
    "1. What is a \"context manager\"? What special methods need to be defined to create one?\n",
    "1. Why can't we use plain CAM for the inner layers of a network?\n",
    "1. Why do we need to register a hook on the backward pass in order to do Grad-CAM?\n",
    "1. Why can't we call `output.backward()` when `output` is a rank-2 tensor of output activations per image per class?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further Research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Try removing `keepdim` and see what happens. Look up this parameter in the PyTorch docs. Why do we need it in this notebook?\n",
    "1. Create a notebook like this one, but for NLP, and use it to find which words in a movie review are most significant in assessing the sentiment of a particular movie review."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
